####################################################
#Author: Kelli Marquardt
#Purpose: Estimate the model with a separate diagnostic threshold (tau) for each ADHD subtype and produce Table 9

# Inputs:
#-  data/est_dat_fake.csv 

# Outputs:
#-  output/tables/tab_9.txt

####################################################

############################
#0 load required packages
############################
# Start from an empty workspace (including hidden objects) so no objects from
# a previous session leak into the estimation.
# NOTE(review): when this script is source()d into the global environment this
# wipes everything there; prefer running it via Rscript.
rm(list = ls(all.names = TRUE))

#load packages 
library(dplyr) 
library(fastDummies)
library(tidyr) #for pivot wider

############################
#0 Number of bootstrap samples
############################
# Number of *successful* bootstrap draws to collect in Step 4 (the paper uses 100).
num_bs = 100

############################
#0 Define functions used in estimation 
############################

#demean: subtract the mean (ignoring NAs) from a non-character vector;
#character vectors are returned unchanged.
#is.character() is used rather than class(x)=="character" because class()
#can return several values (e.g. c("matrix","array")), and a length > 1
#condition makes if() fail in R >= 4.2.
demean=function(x){
  if (is.character(x)) {
    return(x)
  }
  x - mean(x, na.rm=TRUE)
}


#recover_params_2tau: back out the two-threshold model parameters for one
#gender from reduced-form moments. As called from bootstrap_est:
#  mu           - population mean risk (first-stage extrapolation)
#  qhat         - share tested (mean of Qi)
#  xbar         - mean of xi among the tested
#  bhat, bhat2, ahat - from the probit of Di on xi and type2_max among the
#                 tested: slope on xi, slope on type2_max, intercept
#  sigma_lower, sigma_upper - search interval for sigma
#sigma is the root of rho_mean(s) - rho_slope(s), where
#  rho_slope(s) = b*s / sqrt(1 + (b*s)^2)
#  rho_mean(s)  = k * sqrt(1 + s^2) / s^2,  k = (xbar - mu) * qhat / dnorm(qnorm(qhat))
#Returns list(rho, sig, tau, cbar, alpha): tau is the type-1 threshold and
#tau + alpha the type-2 threshold. Stops if the interval does not bracket a root.
recover_params_2tau=function(mu, qhat, xbar, bhat, bhat2,  ahat,
                             sigma_lower=1e-8, sigma_upper=1e3){

  z_q=qnorm(qhat)
  k_tilde=(xbar-mu)*qhat/dnorm(z_q)

  #correlation implied by the probit slope on xi, for a candidate sigma s
  rho_slope=function(s){
    bs=bhat*s
    bs/sqrt(1+(bs^2))
  }
  #correlation implied by the selected mean xbar, for a candidate sigma s
  rho_mean=function(s) k_tilde*sqrt(1+(s^2))/(s^2)
  gap=function(s) rho_mean(s) - rho_slope(s)

  #uniroot needs finite values of opposite sign (or a zero) at the two ends
  gap_lo=gap(sigma_lower)
  gap_hi=gap(sigma_upper)
  bracketed=is.finite(gap_lo) && is.finite(gap_hi) && gap_lo*gap_hi <= 0
  if (!bracketed) {
    stop("Could not bracket a root for sigma. Try different sigma_lower/sigma_upper.")
  }

  sig_hat=uniroot(gap, lower=sigma_lower, upper=sigma_upper)$root
  rho_hat=rho_slope(sig_hat)
  resid_sd=sig_hat*sqrt(1-(rho_hat^2))

  list(rho   = rho_hat,
       sig   = sig_hat,
       tau   = (1 - rho_hat) * mu - ahat * sig_hat * sqrt(1 - (rho_hat^2) ),
       cbar  = mu - z_q * sqrt(1 + (sig_hat^2)),
       alpha = -1*bhat2*resid_sd)
}



#################################################
#Step 1: read in data and build the step-1 / step-2 inputs
#################################################

est_dat = read.csv(file.path("..", "data", "est_dat_fake.csv"), stringsAsFactors = FALSE)

#the *_base suffix marks the full, unresampled data; bootstrap_est draws
#patients from these two tables
step1_dat_base = select(est_dat,
                        pat_id, male, hispanic, white, Med, Com, birth_year,
                        first_pcp_id, Qi, xi,
                        ends_with("_uptoQ1"))

step2_dat_base = select(est_dat, pat_id, male, Qi, xi, Di, type2_max)




#################################################
#Step 2: Prep/clean the step-1 data once, outside the bootstrap loop
#################################################
#strip the _uptoQ1 suffix, then add age^2, the total appointment count
#(year_2014..year_2017) and birth-year dummies, and finally keep only the
#variables used in the first-stage probit
step1_dat_base = step1_dat_base %>%
  rename_with(~ sub("_uptoQ1", "", .x), ends_with("_uptoQ1")) %>%
  rename(age = age_mean) %>%
  mutate(age2 = age^2,
         numapt = year_2014 + year_2015 + year_2016 + year_2017) %>%
  dummy_cols(select_columns = "birth_year") %>%
  select(pat_id, first_pcp_id, Qi, male, xi,
         white, hispanic, numapt, numdocs, age, age2,
         starts_with("year"), starts_with("birth_year_"),
         Med, Com, psych_doc, well, behav)

#################################################
#Step 3: Create a function that takes in a vector of patient IDs and estimates the 2-tau model estimates
#################################################

#resample_by_id: rows of `dat` for the ids in pat_id_vec, in that order; an id
#that appears k times in pat_id_vec contributes k copies of its (first) row,
#which is what makes a with-replacement bootstrap draw work.
resample_by_id=function(dat, pat_id_vec){
  dat%>%
    filter(pat_id %in% pat_id_vec)%>%
    slice(match(pat_id_vec, pat_id))
}

#pcpg_coef_table: extract the as.factor(pcpg) coefficients of a fitted lm/glm
#as a data frame (pcpg, est, se), where pcpg is the level label.
pcpg_coef_table=function(fit){
  prefix="as.factor(pcpg)"
  coefs=as.data.frame(coef(summary(fit)))
  coefs$varname=rownames(coefs)
  coefs%>%
    filter(startsWith(varname, prefix))%>%
    mutate(pcpg=substr(varname, nchar(prefix)+1, nchar(varname)))%>%
    select(pcpg, est=Estimate, se=`Std. Error`)
}

#first_stage_mu: first stage on a (possibly resampled) step-1 sample.
#  1. probit of Qi on demeaned covariates plus PCP-by-gender fixed effects;
#     pnorm(fixed effect) is the cell's testing propensity (fe_est)
#  2. mean of xi among tested patients in each cell (xi_est, s.e. xi_se)
#  3. by gender, fit xi_est = b0*exp(b1*fe_est) with nls (weights 1/xi_se) and
#     extrapolate to fe_est = 1: mu = b0*exp(b1)
#Returns c(mu_f, mu_m); both are 0 if either nls fit fails or either b1 > 0.
first_stage_mu=function(step1_dat){

  #keep PCPs with more than two patients of each gender (i.e. at least three)
  #whose gender-specific testing share is strictly between 0 and 1 (an all-0 or
  #all-1 cell would make its probit fixed effect diverge)
  #NOTE(review): an earlier comment said PCPs with n_gender < 2 are dropped,
  #but the cutoff in code is n_gender > 2 -- confirm the intended threshold
  step1_dat=step1_dat%>%
    group_by(first_pcp_id)%>%
    mutate(num_male=sum(male),
           num_female=n()-num_male)%>%
    filter(num_male>2 & num_female>2)%>%
    mutate(meanQ_male=sum(Qi*male)/num_male,
           meanQ_female=sum(Qi*(1-male))/num_female)%>%
    filter(!(meanQ_male %in% c(0,1) | meanQ_female %in% c(0,1)))%>%
    ungroup()%>%
    select(-c(num_male, num_female, meanQ_male, meanQ_female))

  #drop rows with a missing covariate, demean the covariates, and build the
  #PCP-by-gender cell id
  dat_fs=step1_dat%>%
    filter(if_all(-c(pat_id, first_pcp_id, Qi, male, xi), ~ !is.na(.x)))%>%
    mutate(across(-c(pat_id, first_pcp_id, Qi, male, xi), ~ demean(.x)))%>%
    mutate(pcpg=paste0(first_pcp_id, "_", male))

  #probit: Qi on every covariate plus a full set of pcpg dummies, no intercept
  covars=setdiff(names(dat_fs), c("pat_id","Qi","first_pcp_id","male","pcpg", "xi"))
  reg_fs=as.formula(paste("Qi ~", paste(covars, collapse=" + "), "+ as.factor(pcpg) - 1"))
  step1_reg=glm(reg_fs, family=binomial(link="probit"), data=dat_fs)

  pcp_fe=pcpg_coef_table(step1_reg)%>%
    transmute(pcpg, fe_est=pnorm(est), fe_se=se)

  #attach PCP id and gender to each cell
  pcp_fe=dat_fs%>%
    select(first_pcp_id, male, pcpg)%>%
    distinct()%>%
    right_join(pcp_fe, by="pcpg")

  #mean of xi (and its s.e.) among the tested in each estimated cell
  xi_dat=step1_dat%>%
    filter(Qi==1)%>%
    select(pat_id, male, xi, first_pcp_id)%>%
    left_join(pcp_fe, by=c("first_pcp_id", "male"))%>%
    filter(!is.na(fe_est))
  xi_tab=pcpg_coef_table(lm(xi ~ as.factor(pcpg)-1, data=xi_dat))%>%
    transmute(pcpg, xi_est=est, xi_se=se)

  extrap_dat=pcp_fe%>%
    left_join(xi_tab, by="pcpg")%>%
    mutate(wt=1/(xi_se))

  #exponential extrapolations by gender (errors are captured, not raised)
  nls_hold_m = try(nls(xi_est ~ b0 * exp(b1 * fe_est),
                       data = extrap_dat,
                       subset = (male==1),
                       weights = wt,
                       start = list(b0 = 0.5, b1 = -0.5),
                       control = nls.control(minFactor = 1e-5, maxiter = 100)),
                   silent = TRUE)
  nls_hold_f = try(nls(xi_est ~ b0 * exp(b1 * fe_est),
                       data = extrap_dat,
                       subset = (male==0),
                       weights = wt,
                       start = list(b0 = 0.5, b1 = -0.5),
                       control = nls.control(minFactor = 1e-5, maxiter = 100)),
                   silent = TRUE)

  if (inherits(nls_hold_m, "try-error") || inherits(nls_hold_f, "try-error")) {
    return(c(0, 0))
  }

  coef_m=summary(nls_hold_m)$coefficients
  coef_f=summary(nls_hold_f)$coefficients

  #require b1 <= 0 (mean xi not increasing in the testing rate) for both genders
  if (coef_m["b1", "Estimate"]>0 || coef_f["b1", "Estimate"]>0) {
    return(c(0, 0))
  }

  c(coef_f["b0", "Estimate"]*exp(coef_f["b1", "Estimate"]),
    coef_m["b0", "Estimate"]*exp(coef_m["b1", "Estimate"]))
}

#second_stage_params: for gender m (0 = female, 1 = male) and its first-stage
#mu, fit the probit of Di on xi and type2_max among the tested and back out
#the structural parameters with recover_params_2tau.
#Returns named numeric c(mu, sig, cbar, rho, tau, alpha). All are 0 when
#xbar < mu, the xi slope is negative, or xbar / the xi slope is NA.
second_stage_params=function(step2_dat, m, mu){
  zero_out=c(mu=0, sig=0, cbar=0, rho=0, tau=0, alpha=0)

  dat_g=step2_dat%>%
    filter(male==m)%>%
    select(Qi, xi, Di, type2_max)
  dat_condit=dat_g%>%
    filter(Qi==1)

  #unname so coefficient names do not leak into the parameter values
  probit_b=unname(coef(glm(Di ~ xi + type2_max, family=binomial(link="probit"),
                           data=dat_condit)))
  cons=probit_b[1]
  slope=probit_b[2]
  slope2=probit_b[3]

  xbar=mean(dat_condit$xi)
  qhat=mean(dat_g$Qi)
  sd_max=sd(dat_condit$xi)*100 #upper end of the sigma search interval

  #identification needs xbar >= mu and a non-negative slope on xi
  if (is.na(xbar) || is.na(slope) || xbar<mu || slope<0) {
    return(zero_out)
  }

  param_out=recover_params_2tau(mu=mu, qhat=qhat, xbar=xbar,
                                bhat=slope, bhat2=slope2, ahat=cons,
                                sigma_upper=sd_max)

  c(mu=mu, sig=param_out$sig, cbar=param_out$cbar, rho=param_out$rho,
    tau=param_out$tau, alpha=param_out$alpha)
}

#bootstrap_est: estimate the two-tau model on the patients in pat_id_vec
#(duplicated ids are kept, so this serves both the baseline and the bootstrap).
#Returns a data frame with exactly two rows (male = 0, 1) and columns
#male, mu, sig, c_bar, rho, alpha, tau_t1, tau_t2 (tau_t2 = tau_t1 + alpha).
#A row of zeros marks a failed estimate. Rows with NA parameters are kept
#(not dropped), so callers always see both genders.
#Errors from recover_params_2tau (no bracketing root) propagate to the caller.
bootstrap_est=function(pat_id_vec){

  mu_fm=first_stage_mu(resample_by_id(step1_dat_base, pat_id_vec)) #c(mu_f, mu_m)

  step2_dat=resample_by_id(step2_dat_base, pat_id_vec)
  fs_failed=(max(mu_fm)==0)

  est_rows=lapply(c(0, 1), function(m){
    p=if (fs_failed) {
      c(mu=0, sig=0, cbar=0, rho=0, tau=0, alpha=0)
    } else {
      second_stage_params(step2_dat, m, mu_fm[m+1])
    }
    parm_vec=c(m, p[["mu"]], p[["sig"]], p[["cbar"]], p[["rho"]],
               p[["alpha"]], p[["tau"]], p[["tau"]]+p[["alpha"]])
    names(parm_vec)=c("male","mu","sig", "c_bar","rho","alpha", "tau_t1","tau_t2")
    parm_vec
  })

  boot_est=as.data.frame(do.call(rbind, est_rows))
  row.names(boot_est)=NULL
  boot_est
}



#################################################
#Step 4: Estimate on the full sample, then collect num_bs bootstrap estimates
#################################################

#####
#4a: run on the baseline (full) sample and hold the point estimates
#####
pat_id_base=unique(step2_dat_base$pat_id)
main_est_2tau=bootstrap_est(pat_id_base)

#bootstrap_est reports a failed estimate as zeros; flag it instead of
#tabulating it silently
if (!all(is.finite(as.matrix(main_est_2tau))) || any(main_est_2tau$mu <= 0)) {
  warning("Baseline estimation failed for at least one gender (mu <= 0 or non-finite estimates); table 9 will report these values as-is.")
}

#####
#4b: draw bootstrap samples until num_bs of them estimate successfully
#####
max_attempts=50*num_bs # stop rather than loop forever if draws keep failing
bs_draws=vector("list", num_bs) # one data frame per successful draw
attempt=0 # count how many bootstrap samples have been estimated
keep=0 # count how many bootstrap samples were successfully estimated

while (keep < num_bs){

  attempt=attempt+1
  if (attempt > max_attempts) {
    stop(sprintf("Only %d of %d bootstrap samples succeeded after %d attempts.",
                 keep, num_bs, max_attempts))
  }

  #seed by attempt number so every draw is reproducible
  set.seed(attempt)
  pat_id_bs=sample(pat_id_base, size=length(pat_id_base), replace=TRUE)

  #estimate quietly; an error (e.g. no bracketing root for sigma) marks this
  #draw as failed instead of aborting the whole script
  bs_est=tryCatch(
    suppressMessages(suppressWarnings(bootstrap_est(pat_id_bs))),
    error=function(e) NULL
  )

  #accept only draws with both genders, all-finite estimates and mu > 0
  ok=!is.null(bs_est) &&
    nrow(bs_est)==2 &&
    all(is.finite(as.matrix(bs_est))) &&
    min(bs_est$mu)>0
  if (ok) {
    keep=keep+1
    bs_est$iter=keep
    bs_draws[[keep]]=bs_est
  }
}

#stack the draws: one row per draw and gender, tagged by iter
bootstrap_dat=as.data.frame(bind_rows(bs_draws))


#################################################
#Step 5: Produce the two tau estimate table (table 9)
#################################################

#one column per parameter and gender: suffix _1 = male, _0 = female
main_est=pivot_wider(main_est_2tau,
                     names_from=male,
                     values_from=c(mu, sig, c_bar, rho, tau_t1, tau_t2, alpha))

#bootstrap standard errors: SD of every wide column across draws (iter included)
bootstrap_dat=bootstrap_dat%>%
  pivot_wider(names_from=male,
              values_from=c(mu, sig, c_bar, rho, tau_t1, tau_t2, alpha))%>%
  summarise(across(everything(), sd))

#LaTeX tabular: for each parameter, a row of point estimates (male, female,
#male - female) followed by a row of bootstrap standard errors
est_table_2tau=local({
  table_rows=list(
    c(label="Pop. Mean Risk $\\mu_\\theta$", stem="mu"),
    c(label="Pop. Risk Dispersion $\\sigma_\\theta$", stem="sig"),
    c(label="Utilization Costs $c_\\theta$", stem="c_bar"),
    c(label="Signal Quality $\\rho_\\theta$", stem="rho"),
    c(label="Diagnostic Threshold, Type 1 $\\tau_{1\\theta}$", stem="tau_t1"),
    c(label="Diagnostic Threshold, Type 2 $\\tau_{2\\theta}$", stem="tau_t2")
  )

  param_lines=lapply(table_rows, function(r){
    est_m=main_est[[paste0(r[["stem"]], "_1")]]
    est_f=main_est[[paste0(r[["stem"]], "_0")]]
    list(
      sprintf(paste0(r[["label"]], " & %.3f & %.3f & \\multirow{2}{*}{%.3f} \\\\\n"),
              est_m, est_f, est_m-est_f),
      sprintf(" & (%.3f) & (%.3f) &  \\\\\n",
              bootstrap_dat[[paste0(r[["stem"]], "_1")]],
              bootstrap_dat[[paste0(r[["stem"]], "_0")]])
    )
  })

  pieces=c(list("\\begin{tabular}{lccc}\n",
                "\\toprule\n",
                "& Male & Female & Difference \\\\\n",
                "\\midrule\n",
                "\\addlinespace \n"),
           unlist(param_lines, recursive=FALSE),
           list("\\bottomrule\n",
                "\\end{tabular}\n"))

  #joined with paste()'s default single-space separator
  do.call(paste, pieces)
})

#save table to output
write(est_table_2tau, file = file.path("..", "output", "tables", "tab_9.txt"))

#END OF SCRIPT

